﻿// Accord Statistics Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2017
// cesarsouza at gmail.com
//
//    This library is free software; you can redistribute it and/or
//    modify it under the terms of the GNU Lesser General Public
//    License as published by the Free Software Foundation; either
//    version 2.1 of the License, or (at your option) any later version.
//
//    This library is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//    Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public
//    License along with this library; if not, write to the Free Software
//    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
//

namespace Accord.Statistics.Filters
{
    using Accord.MachineLearning;
    using Accord.Math;
    using System;
    using System.Collections.Generic;
    using System.Linq.Expressions;
    using System.Data;
    using Accord.Compat;

    /// <summary>
    ///   Value discretization preprocessing filter.
    /// </summary>
    /// 
    /// <remarks>
    ///   This filter converts ranges of values into a different representation
    ///   according to a set of rules. Please see the examples below to see how
    ///   this filter can be used in practice.</remarks>
    /// 
    /// <example>
    /// <para>
    ///   The discretization filter can be used to convert any range of values into another representation. For example,
    ///   let's say we have a dataset where a column represents percentages using floating point numbers, but we would 
    ///   like to discretize those numbers into more descriptive labels:</para>
    ///   <code source="Unit Tests\Accord.Tests.Statistics\Filters\DiscretizationFilterTest.cs" region="doc_percentage" />
    ///   
    /// <para>
    ///   The discretization filter can also be used to process <c>DataTable</c> like the <see cref="Codification"/> filter. 
    ///   It can also be used in combination with <see cref="Codification"/> to process datasets for classification, as shown
    ///   in the example below:</para>
    ///   <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\C45LearningTest.cs" region="doc_missing_thresholds" />
    /// </example>
    /// 
    /// <seealso cref="Codification"/>
    /// 
    [Serializable]
    public partial class Discretization<TInput, TOutput> : BaseFilter<Discretization<TInput, TOutput>.Options, Discretization<TInput, TOutput>>, 
        IAutoConfigurableFilter,
        ITransform<TInput[], TOutput[]>, 
        IUnsupervisedLearning<Discretization<TInput, TOutput>, TInput[], TOutput[]>
    {

        /// <summary>
        ///   Gets the number of outputs generated by the model, computed as the
        ///   sum of <c>NumberOfSymbols</c> over every column registered in the filter.
        /// </summary>
        /// 
        /// <value>The total number of symbols across all columns.</value>
        /// 
        /// <exception cref="InvalidOperationException">
        ///   Always thrown by the setter, since this property is read only.</exception>
        /// 
        public int NumberOfOutputs
        {
            get
            {
                int sum = 0;

                for (int i = 0; i < Columns.Count; i++)
                {
                    sum += Columns[i].NumberOfSymbols;
                }

                return sum;
            }
            set
            {
                throw new InvalidOperationException("This property is read only.");
            }
        }

        /// <summary>
        ///   Initializes a new instance of the discretization filter. This
        ///   constructor does not register any column by itself.
        /// </summary>
        /// 
        public Discretization()
            : base()
        {
        }

#if !NETSTANDARD1_4
        /// <summary>
        ///   Initializes a new instance of the discretization filter and
        ///   immediately learns its configuration from the given table.
        /// </summary>
        /// 
        /// <param name="data">The table handed over to <c>Learn</c>.</param>
        /// 
        public Discretization(DataTable data)
            : this()
        {
            Learn(data);
        }
#endif

        /// <summary>
        ///   Initializes a new instance of the discretization filter, registering
        ///   one set of column options for each of the given column names.
        /// </summary>
        /// 
        /// <param name="columns">The names of the columns to be registered, in order.</param>
        /// 
        public Discretization(params string[] columns)
            : this()
        {
            foreach (string name in columns)
            {
                Columns.Add(new Options(name));
            }
        }

        /// <summary>
        ///   Initializes a new instance of the discretization filter, registering
        ///   one column per given name and letting each column learn from the
        ///   data column found at the same position.
        /// </summary>
        /// 
        /// <param name="columns">The names of the columns to be registered.</param>
        /// <param name="data">The samples; column <c>i</c> of this jagged array is
        ///   given to the options created for <c>columns[i]</c>.</param>
        /// 
        public Discretization(string[] columns, object[][] data)
            : this()
        {
            for (int index = 0; index < columns.Length; index++)
            {
                // Register the column first, then feed it the matching slice of the data
                var registered = Columns.Add(new Options(columns[index]));
                var samples = data.GetColumn(index);
                registered.Learn(samples);
            }
        }

        /// <summary>
        ///   Initializes a new instance of the discretization filter, registering
        ///   one column per given name and letting each column learn from the
        ///   typed data column found at the same position.
        /// </summary>
        /// 
        /// <param name="columns">The names of the columns to be registered.</param>
        /// <param name="data">The samples; column <c>i</c> of this jagged array is
        ///   given to the options created for <c>columns[i]</c>.</param>
        /// 
        public Discretization(string[] columns, TInput[][] data)
            : this()
        {
            for (int index = 0; index < columns.Length; index++)
            {
                // Register the column first, then feed it the matching slice of the data
                var registered = Columns.Add(new Options(columns[index]));
                var samples = data.GetColumn(index);
                registered.Learn(samples);
            }
        }

#if !NETSTANDARD1_4
        /// <summary>
        ///   Initializes a new instance of the discretization filter. For every
        ///   given column name a new set of options is created, asked to learn
        ///   from the table, and the result of that call is registered.
        /// </summary>
        /// 
        /// <param name="data">The table each column learns from.</param>
        /// <param name="columns">The names of the columns to be registered.</param>
        /// 
        public Discretization(DataTable data, params string[] columns)
            : this()
        {
            foreach (string name in columns)
            {
                var options = new Options(name);
                Columns.Add(options.Learn(data));
            }
        }
#endif

        /// <summary>
        ///   Translates a single value of a given variable into its codeword
        ///   representation, using the options registered under that column name.
        /// </summary>
        /// 
        /// <param name="columnName">The name of the variable's data column.</param>
        /// <param name="value">The value to be translated.</param>
        /// 
        /// <returns>The output produced by the named column for <paramref name="value"/>.</returns>
        /// 
        public TOutput Transform(string columnName, TInput value)
        {
            var column = Columns[columnName];

            return column.Transform(value);
        }

        /// <summary>
        ///   Translates a set of rows into their codeword representation, assuming
        ///   the values in each row are given in the original order of the columns.
        /// </summary>
        /// 
        /// <param name="data">The rows to be translated.</param>
        /// 
        /// <returns>A new jagged array holding one translated row per input row.</returns>
        /// 
        public TOutput[][] Transform(object[][] data)
        {
            int rows = data.Length;
            var translated = new TOutput[rows][];

            for (int row = 0; row < rows; row++)
            {
                // Each row is handled by the single-row overload
                translated[row] = Transform(data[row]);
            }

            return translated;
        }

        /// <summary>
        ///   Translates an array of loosely-typed values into their codeword
        ///   representation, assuming values are given in original order of columns.
        /// </summary>
        /// 
        /// <remarks>
        ///   A value that already is a <c>TOutput</c> is copied as-is; a value that is
        ///   a <c>TInput</c> is translated by the column at the same position; any other
        ///   value (including <c>null</c>) leaves the default output value in place.
        /// </remarks>
        /// 
        /// <param name="data">The values to be translated.</param>
        /// 
        /// <returns>An array with one output per input value.</returns>
        /// 
        /// <exception cref="ArgumentException">
        ///   The array has more values than there are known columns.</exception>
        /// 
        public TOutput[] Transform(object[] data)
        {
            int count = data.Length;

            if (count > Columns.Count)
            {
                throw new ArgumentException("The array contains more values"
                    + " than the number of known columns.", "data");
            }

            var translated = new TOutput[count];

            for (int i = 0; i < count; i++)
            {
                object value = data[i];

                // NOTE(review): the TOutput test runs first, so when TInput and TOutput
                // are the same type the column mapping is never applied here - confirm
                // that this pass-through is intended.
                if (value is TOutput)
                {
                    translated[i] = (TOutput)value;
                    continue;
                }

                if (value is TInput)
                {
                    translated[i] = Columns[i].Transform((TInput)value);
                }
            }

            return translated;
        }

        /// <summary>
        ///   Translates an array of values into their codeword representation, 
        ///   assuming values are given in original order of columns.
        /// </summary>
        /// 
        /// <remarks>
        ///   When exactly one column is registered, the whole array is handed to
        ///   that column's array transform instead of being matched by position.
        /// </remarks>
        /// 
        /// <param name="data">The values to be translated.</param>
        /// 
        /// <returns>The translated values.</returns>
        /// 
        /// <exception cref="ArgumentException">
        ///   The array has more values than there are known columns.</exception>
        /// 
        public TOutput[] Transform(params TInput[] data)
        {
            int known = Columns.Count;

            // Single-column filters treat the array as many samples of one variable
            if (known == 1)
            {
                return Columns[0].Transform(data);
            }

            if (data.Length > known)
            {
                throw new ArgumentException("The array contains more values"
                    + " than the number of known columns.", "data");
            }

            var translated = new TOutput[data.Length];

            for (int position = 0; position < data.Length; position++)
            {
                translated[position] = Columns[position].Transform(data[position]);
            }

            return translated;
        }

#if !NETSTANDARD1_4
        /// <summary>
        ///   Translates the values of a data row into their codeword representation,
        ///   producing one output per requested column name.
        /// </summary>
        /// 
        /// <param name="row">A <see cref="DataRow"/> containing the values to be translated.</param>
        /// <param name="columnNames">The columns of the <paramref name="row"/> containing the
        /// values to be translated.</param>
        /// 
        /// <returns>The outputs, in the same order as <paramref name="columnNames"/>.</returns>
        /// 
        public TOutput[] Transform(DataRow row, params string[] columnNames)
        {
            int count = columnNames.Length;
            var translated = new TOutput[count];

            for (int k = 0; k < count; k++)
            {
                // The row itself is passed on; the column options receive it whole
                var column = Columns[columnNames[k]];
                translated[k] = column.Transform(row);
            }

            return translated;
        }
#endif

        /// <summary>
        ///   Translates one value for each of the given variables into
        ///   its codeword representation.
        /// </summary>
        /// 
        /// <param name="columnNames">The names of the variables' data columns.</param>
        /// <param name="values">The values to be translated; <c>values[i]</c> is
        ///   translated by the column named <c>columnNames[i]</c>.</param>
        /// 
        /// <returns>The translated values, in the same order as the inputs.</returns>
        /// 
        /// <exception cref="ArgumentException">
        ///   The two arrays do not have the same length.</exception>
        /// 
        public TOutput[] Transform(string[] columnNames, TInput[] values)
        {
            bool sameLength = columnNames.Length == values.Length;

            if (!sameLength)
            {
                throw new ArgumentException("The number of column names"
                    + " and the number of values must match.", "values");
            }

            var translated = new TOutput[values.Length];

            for (int k = 0; k < columnNames.Length; k++)
            {
                var column = Columns[columnNames[k]];
                translated[k] = column.Transform(values[k]);
            }

            return translated;
        }

        /// <summary>
        ///   Translates several values of a single variable into
        ///   their codeword representation.
        /// </summary>
        /// 
        /// <param name="columnName">The variable name.</param>
        /// <param name="values">The values to be translated.</param>
        /// 
        /// <returns>The outputs produced by the named column for the given values.</returns>
        /// 
        public TOutput[] Transform(string columnName, TInput[] values)
        {
            var column = Columns[columnName];

            return column.Transform(values);
        }

        /// <summary>
        ///   Translates a jagged array of values of a single variable into
        ///   their codeword representation.
        /// </summary>
        /// 
        /// <param name="columnName">The variable name.</param>
        /// <param name="values">The values to be translated.</param>
        /// 
        /// <returns>The result of applying the named column's transform
        ///   through the <c>Apply</c> extension method.</returns>
        /// 
        public TOutput[][] Transform(string columnName, TInput[][] values)
        {
            // The column is looked up inside the lambda, i.e. once per applied element
            return values.Apply(item => Columns[columnName].Transform(item));
        }

        /// <summary>
        ///   Translates a set of input vectors into their codeword representation,
        ///   allocating a new output buffer with the same shape as the input.
        /// </summary>
        /// 
        /// <param name="input">The values to be translated. Each row is expected to
        ///   hold at least one value per registered column.</param>
        /// 
        /// <returns>A newly allocated jagged array where row <c>j</c> has the same
        ///   length as <c>input[j]</c> and holds the translated values.</returns>
        /// 
        public TOutput[][] Transform(TInput[][] input)
        {
            // Every output row must exist before delegating: the two-argument
            // overload writes into result[j][i] and never allocates rows itself,
            // so passing an array of null rows would fail on the first write.
            var result = new TOutput[input.Length][];

            for (int j = 0; j < result.Length; j++)
                result[j] = new TOutput[input[j].Length];

            return Transform(input, result);
        }

        /// <summary>
        ///   Applies the transformation to a set of input vectors, writing
        ///   the outputs into a caller-supplied buffer.
        /// </summary>
        /// 
        /// <param name="input">The input data to which
        ///   the transformation should be applied.</param>
        /// <param name="result">The location where the result of this transformation
        ///   is stored. Its rows are written in place and are not allocated here.</param>
        /// 
        /// <returns>The same <paramref name="result"/> instance, after one entry
        ///   per registered column has been written into each row.</returns>
        /// 
        public TOutput[][] Transform(TInput[][] input, TOutput[][] result)
        {
            for (int row = 0; row < input.Length; row++)
            {
                TInput[] source = input[row];
                TOutput[] target = result[row];

                // Column k of the filter translates entry k of the row
                for (int k = 0; k < Columns.Count; k++)
                {
                    target[k] = Columns[k].Transform(source[k]);
                }
            }

            return result;
        }


#if !NETSTANDARD1_4
        /// <summary>
        ///   Processes the current filter, building a new table in which every
        ///   mapped column has been translated and every other column is copied.
        /// </summary>
        /// 
        /// <param name="data">The table to be processed. It is only read.</param>
        /// 
        /// <returns>A new table with the same columns as <paramref name="data"/>, where
        ///   mapped columns are retyped to <c>TOutput</c>. <see cref="DBNull"/> cells
        ///   are carried over unchanged.</returns>
        /// 
        protected override DataTable ProcessFilter(DataTable data)
        {
            // Clone() duplicates the schema only; the rows are imported further below
            DataTable output = data.Clone();

            // Retype every mapped column that actually exists in the table
            foreach (Options mapped in Columns)
            {
                string mappedName = mapped.ColumnName;

                if (output.Columns.Contains(mappedName))
                {
                    DataColumn target = output.Columns[mappedName];

                    // Drop any maximum length before switching to the output type
                    target.MaxLength = -1;
                    target.DataType = typeof(TOutput);
                }
            }

            // Import the original rows one by one
            foreach (DataRow source in data.Rows)
            {
                DataRow converted = output.NewRow();

                foreach (DataColumn column in data.Columns)
                {
                    string name = column.ColumnName;
                    object cell = source[name];

                    if (cell is DBNull)
                    {
                        // Missing values stay missing, mapped column or not
                        converted[name] = DBNull.Value;
                    }
                    else if (Columns.Contains(name))
                    {
                        // Mapped column: coerce the cell to TInput, then translate it
                        var mapping = Columns[name];
                        TInput typed = (TInput)System.Convert.ChangeType(cell, typeof(TInput));
                        converted[name] = mapping.Transform(typed);
                    }
                    else
                    {
                        // Unmapped column: the value is copied over untouched
                        converted[name] = cell;
                    }
                }

                output.Rows.Add(converted);
            }

            return output;
        }
#endif

        /// <summary>
        ///   Learns the filter from the samples of a single variable.
        /// </summary>
        /// 
        /// <remarks>
        ///   If no column has been registered yet, one named <c>"0"</c> is created.
        /// </remarks>
        /// 
        /// <param name="x">The model inputs.</param>
        /// <param name="weights">The weight of importance for each input sample;
        ///   forwarded as-is to the column being learned.</param>
        /// 
        /// <returns>This same filter instance.</returns>
        /// 
        /// <exception cref="Exception">
        ///   The filter does not have exactly one column.</exception>
        /// 
        public Discretization<TInput, TOutput> Learn(TInput[] x, double[] weights = null)
        {
            if (Columns.Count == 0)
            {
                Columns.Add(new Options("0"));
            }

            if (Columns.Count != 1)
            {
                throw new Exception("There are more predefined columns than columns in the data.");
            }

            var single = Columns[0];
            single.Learn(x, weights);

            return this;
        }

        /// <summary>
        ///   Learns the filter from a set of input vectors, one filter
        ///   column per column of the data.
        /// </summary>
        /// 
        /// <remarks>
        ///   Columns missing from the filter are created and named after
        ///   their zero-based position in the data.
        /// </remarks>
        /// 
        /// <param name="x">The model inputs.</param>
        /// <param name="weights">The weight of importance for each input sample;
        ///   forwarded as-is to every column being learned.</param>
        /// 
        /// <returns>This same filter instance.</returns>
        /// 
        /// <exception cref="Exception">
        ///   The number of filter columns differs from the number of data columns
        ///   after the missing ones have been added.</exception>
        /// 
        public Discretization<TInput, TOutput> Learn(TInput[][] x, double[] weights = null)
        {
            int width = x.Columns();

            // Top up the filter until it has one column per data column
            for (int position = Columns.Count; position < width; position++)
            {
                Columns.Add(new Options(position.ToString()));
            }

            if (Columns.Count != width)
            {
                throw new Exception("There are more predefined columns than columns in the data.");
            }

            for (int position = 0; position < Columns.Count; position++)
            {
                var column = Columns[position];
                var samples = x.GetColumn(position);
                column.Learn(samples, weights);
            }

            return this;
        }

#if !NETSTANDARD1_4
        /// <summary>
        ///   Learns the filter from a data table, considering only the table
        ///   columns whose data type is exactly <c>TInput</c>.
        /// </summary>
        /// 
        /// <param name="x">The model inputs.</param>
        /// <param name="weights">Must be <c>null</c>; sample weights are not supported
        ///   by this overload.</param>
        /// 
        /// <returns>This same filter instance.</returns>
        /// 
        /// <exception cref="ArgumentException">
        ///   <paramref name="weights"/> is not <c>null</c>.</exception>
        /// 
        public Discretization<TInput, TOutput> Learn(DataTable x, double[] weights = null)
        {
            if (weights != null)
            {
                throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");
            }

            Type inputType = typeof(TInput);

            // First pass: register the TInput-typed columns that are not known yet
            foreach (DataColumn column in x.Columns)
            {
                string name = column.ColumnName;

                if (!Columns.Contains(name) && column.DataType == inputType)
                {
                    Columns.Add(new Options(name));
                }
            }

            // Second pass: let every TInput-typed column learn from the table
            foreach (DataColumn column in x.Columns)
            {
                if (column.DataType != inputType)
                    continue;

                Columns[column.ColumnName].Learn(x, weights);
            }

            return this;
        }

        /// <summary>
        ///   Auto detects the filter options by analyzing a given <see cref="System.Data.DataTable"/>.
        ///   Every column typed as <see cref="Double"/> or <see cref="Decimal"/> that is
        ///   not registered yet is added to the filter.
        /// </summary> 
        /// 
        /// <param name="data">The table whose columns are inspected.</param>
        /// 
        public void Detect(DataTable data)
        {
            foreach (DataColumn column in data.Columns)
            {
                Type type = column.DataType;

                // Only continuous numeric types are candidates for discretization
                bool continuous = type == typeof(Double) || type == typeof(Decimal);
                if (!continuous)
                    continue;

                string name = column.ColumnName;

                if (!Columns.Contains(name))
                {
                    Columns.Add(new Options(name));
                }
            }
        }
#endif

        /// <summary>
        ///   Adds a matching rule with a fixed output to a column, registering
        ///   the column first if the filter does not know it yet.
        /// </summary>
        /// 
        /// <param name="columnName">Name of the column.</param>
        /// <param name="rule">The rule.</param>
        /// <param name="output">The output that should be generated whenever a data sample matches with the rule.</param>
        /// 
        public void Add(string columnName, Expression<Func<TInput, bool>> rule, TOutput output)
        {
            bool known = Columns.Contains(columnName);

            if (!known)
            {
                Columns.Add(new Options(columnName));
            }

            // The constant is wrapped in a lambda that ignores its argument
            var column = this[columnName];
            column.Mapping.Add(rule, x => output);
        }

        /// <summary>
        ///   Adds a matching rule with a computed output to a column, registering
        ///   the column first if the filter does not know it yet.
        /// </summary>
        /// 
        /// <param name="columnName">Name of the column.</param>
        /// <param name="rule">The rule.</param>
        /// <param name="output">The expression registered together with the rule in the
        ///   column's mapping, to produce the output for matching data samples.</param>
        /// 
        public void Add(string columnName, Expression<Func<TInput, bool>> rule, Expression<Func<TInput, TOutput>> output)
        {
            bool known = Columns.Contains(columnName);

            if (!known)
            {
                Columns.Add(new Options(columnName));
            }

            var column = this[columnName];
            column.Mapping.Add(rule, output);
        }
    }
}
